{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "499ea865-b8bc-4153-a737-cc8a52c838c4",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ourownstory/neural_prophet/blob/master/tutorials/feature-use/collect_predictions.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prediction Collection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bb3ac95-e660-4f36-bd03-5bb3ab68aefb",
   "metadata": {},
   "source": [
    "## Collect Predictions\n",
    "First, let's fit a vanilla model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dc722653-e984-4044-bff4-589311e45b4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'google.colab' in str(get_ipython()):\n",
    "    !pip install git+https://github.com/ourownstory/neural_prophet.git # may take a while\n",
    "    #!pip install neuralprophet # much faster, but may not have the latest upgrades/bugfixes\n",
    "    \n",
    "import pandas as pd\n",
    "from neuralprophet import NeuralProphet, set_log_level\n",
    "set_log_level(\"ERROR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "863bf403-5810-491e-922c-9237d81673d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>141</th>\n",
       "      <td>1960-10-01</td>\n",
       "      <td>461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>142</th>\n",
       "      <td>1960-11-01</td>\n",
       "      <td>390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>143</th>\n",
       "      <td>1960-12-01</td>\n",
       "      <td>432</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             ds    y\n",
       "141  1960-10-01  461\n",
       "142  1960-11-01  390\n",
       "143  1960-12-01  432"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_location = \"https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/\"\n",
    "df = pd.read_csv(data_location + \"air_passengers.csv\")\n",
    "df.tail(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "db4f3490-2652-4fd5-8b27-bd6d6b88bb67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8953df6913da40b9b5c4630716f27ccb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/208 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                     \r"
     ]
    }
   ],
   "source": [
    "m = NeuralProphet(n_lags=5, n_forecasts=3)\n",
    "metrics_train = m.fit(df=df, freq=\"MS\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d621caf-b601-43ca-a64b-a1d7be7453c0",
   "metadata": {},
   "source": [
    "## Collect in-sample predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d2dc84d-c388-4b59-8fe3-89633e9671fe",
   "metadata": {},
   "source": [
    "## Predictions sorted based on forecast target\n",
    "Calling `predict`, we get a `df_forecast` where each `'yhat<i>'` refers to the `<i>`-step-ahead prediction for **this row's datetime being the target**.\n",
    "Here, `<i>` refers to the age of the prediction.\n",
    "\n",
    "For example, `yhat3` is the prediction for this datetime made 3 steps earlier; it is \"3 steps old\".\n",
    "\n",
    "When the dataframe extends into the future (as in the out-of-sample section of this notebook), the last row `1961-03-01` only has a `yhat3`, which was forecasted at the last location with data, `1960-12-01`.\n",
    "Because we lack inputs after that location, we have neither the more recent `yhat1` from `1961-02-01` nor `yhat2` from `1961-01-01`.\n",
    "\n",
    "We also get the individual forecast components, which refer to their respective contribution to `yhat<i>`, forecasted `<i>` steps ago.\n",
    "\n",
    "Components without an added number depend only on time or on future regressors. Neither is lagged, so each has a single value per row."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a1b8a10-9708-45b9-9f68-1a4be0dd1ee9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>y</th>\n",
       "      <th>yhat1</th>\n",
       "      <th>residual1</th>\n",
       "      <th>yhat2</th>\n",
       "      <th>residual2</th>\n",
       "      <th>yhat3</th>\n",
       "      <th>residual3</th>\n",
       "      <th>ar1</th>\n",
       "      <th>ar2</th>\n",
       "      <th>ar3</th>\n",
       "      <th>trend</th>\n",
       "      <th>season_yearly</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>141</th>\n",
       "      <td>1960-10-01</td>\n",
       "      <td>461.0</td>\n",
       "      <td>465.914337</td>\n",
       "      <td>4.914337</td>\n",
       "      <td>471.379517</td>\n",
       "      <td>10.379517</td>\n",
       "      <td>478.984253</td>\n",
       "      <td>17.984253</td>\n",
       "      <td>-198.053421</td>\n",
       "      <td>-192.588257</td>\n",
       "      <td>-184.98349</td>\n",
       "      <td>683.929749</td>\n",
       "      <td>-19.961983</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>142</th>\n",
       "      <td>1960-11-01</td>\n",
       "      <td>390.0</td>\n",
       "      <td>409.738464</td>\n",
       "      <td>19.738464</td>\n",
       "      <td>410.88266</td>\n",
       "      <td>20.88266</td>\n",
       "      <td>422.372314</td>\n",
       "      <td>32.372314</td>\n",
       "      <td>-246.174759</td>\n",
       "      <td>-245.030563</td>\n",
       "      <td>-233.540924</td>\n",
       "      <td>690.632263</td>\n",
       "      <td>-34.719048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>143</th>\n",
       "      <td>1960-12-01</td>\n",
       "      <td>432.0</td>\n",
       "      <td>421.198425</td>\n",
       "      <td>-10.801575</td>\n",
       "      <td>440.99115</td>\n",
       "      <td>8.99115</td>\n",
       "      <td>441.647461</td>\n",
       "      <td>9.647461</td>\n",
       "      <td>-287.63382</td>\n",
       "      <td>-267.841095</td>\n",
       "      <td>-267.184784</td>\n",
       "      <td>697.118591</td>\n",
       "      <td>11.713625</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            ds      y       yhat1  residual1       yhat2  residual2  \\\n",
       "141 1960-10-01  461.0  465.914337   4.914337  471.379517  10.379517   \n",
       "142 1960-11-01  390.0  409.738464  19.738464   410.88266   20.88266   \n",
       "143 1960-12-01  432.0  421.198425 -10.801575   440.99115    8.99115   \n",
       "\n",
       "          yhat3  residual3         ar1         ar2         ar3       trend  \\\n",
       "141  478.984253  17.984253 -198.053421 -192.588257  -184.98349  683.929749   \n",
       "142  422.372314  32.372314 -246.174759 -245.030563 -233.540924  690.632263   \n",
       "143  441.647461   9.647461  -287.63382 -267.841095 -267.184784  697.118591   \n",
       "\n",
       "    season_yearly  \n",
       "141    -19.961983  \n",
       "142    -34.719048  \n",
       "143     11.713625  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(data_location + \"air_passengers.csv\")\n",
    "forecast = m.predict(df)\n",
    "forecast.tail(3)"
   ]
  },
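  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the table above, here is a minimal sketch. The rows are copied by hand from the displayed output rather than recomputed, the name `target_view` is illustrative, and the tolerance is an assumption chosen to absorb float32 rounding. It checks that `residual1` equals `yhat1 - y` and that, in this output, `yhat1` is the sum of `trend`, `season_yearly` and `ar1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Rows copied from the forecast.tail(3) output above (values as displayed).\n",
    "target_view = pd.DataFrame(\n",
    "    {\n",
    "        \"ds\": pd.to_datetime([\"1960-10-01\", \"1960-11-01\", \"1960-12-01\"]),\n",
    "        \"y\": [461.0, 390.0, 432.0],\n",
    "        \"yhat1\": [465.914337, 409.738464, 421.198425],\n",
    "        \"residual1\": [4.914337, 19.738464, -10.801575],\n",
    "        \"ar1\": [-198.053421, -246.174759, -287.63382],\n",
    "        \"trend\": [683.929749, 690.632263, 697.118591],\n",
    "        \"season_yearly\": [-19.961983, -34.719048, 11.713625],\n",
    "    }\n",
    ")\n",
    "\n",
    "# residual1 is the signed error of the one-step-old prediction.\n",
    "residual_ok = np.allclose(target_view[\"yhat1\"] - target_view[\"y\"], target_view[\"residual1\"], atol=1e-3)\n",
    "\n",
    "# In this output, yhat1 is the sum of its components.\n",
    "component_sum = target_view[\"trend\"] + target_view[\"season_yearly\"] + target_view[\"ar1\"]\n",
    "sum_ok = np.allclose(component_sum, target_view[\"yhat1\"], atol=1e-3)\n",
    "\n",
    "print(residual_ok, sum_ok)"
   ]
  },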
  {
   "cell_type": "markdown",
   "id": "0a9c7316-95fd-420e-ac51-1fdf34bd5c51",
   "metadata": {},
   "source": [
    "## Predictions based on forecast start\n",
    "Calling `predict` with `raw=True`, we get a `df` where each `'step<i>'` refers to the `<i>`th step-ahead prediction **starting at this row's datetime**.\n",
    "Here, `<i>` refers to how many steps ahead the prediction is targeted at.\n",
    "\n",
    "For example, `step0` is the prediction for this datetime, and `step1` is the prediction for the next datetime.\n",
    "\n",
    "All the predictions of a particular row were made at the same time: one step before the row's datestamp."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "39e3486e-c86d-4952-a8cd-339817b905fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>step0</th>\n",
       "      <th>step1</th>\n",
       "      <th>step2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>1960-10-01</td>\n",
       "      <td>465.914337</td>\n",
       "      <td>410.882660</td>\n",
       "      <td>441.647461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>1960-11-01</td>\n",
       "      <td>409.738464</td>\n",
       "      <td>440.991150</td>\n",
       "      <td>458.693176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>1960-12-01</td>\n",
       "      <td>421.198425</td>\n",
       "      <td>443.388397</td>\n",
       "      <td>456.959534</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            ds       step0       step1       step2\n",
       "136 1960-10-01  465.914337  410.882660  441.647461\n",
       "137 1960-11-01  409.738464  440.991150  458.693176\n",
       "138 1960-12-01  421.198425  443.388397  456.959534"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(data_location + \"air_passengers.csv\")\n",
    "forecast = m.predict(df, decompose=False, raw=True)\n",
    "forecast.tail(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "953b82fd-3e5a-445c-b276-89607281d5f7",
   "metadata": {},
   "source": [
    "Note that the last row shown, `1960-12-01`, forecasts `1960-12-01`, `1961-01-01` and `1961-02-01`.\n",
    "The last possible forecast, which targets `1961-01-01`, `1961-02-01` and `1961-03-01` with data available at `1960-12-01`, requires a dataframe extended into the future (see the out-of-sample section).\n",
    "\n",
    "\n",
    "Setting `decompose=True` will include the individual forecast components, which refer to their respective contribution to `step<i>` into the future."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f86f16de-9535-4573-8c09-8033b7f27ee3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>step0</th>\n",
       "      <th>step1</th>\n",
       "      <th>step2</th>\n",
       "      <th>trend0</th>\n",
       "      <th>trend1</th>\n",
       "      <th>trend2</th>\n",
       "      <th>season_yearly0</th>\n",
       "      <th>season_yearly1</th>\n",
       "      <th>season_yearly2</th>\n",
       "      <th>ar0</th>\n",
       "      <th>ar1</th>\n",
       "      <th>ar2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>1960-10-01</td>\n",
       "      <td>465.914337</td>\n",
       "      <td>410.882660</td>\n",
       "      <td>441.647461</td>\n",
       "      <td>683.929749</td>\n",
       "      <td>690.632263</td>\n",
       "      <td>697.118591</td>\n",
       "      <td>-19.961983</td>\n",
       "      <td>-34.719048</td>\n",
       "      <td>11.713625</td>\n",
       "      <td>-198.053421</td>\n",
       "      <td>-245.030563</td>\n",
       "      <td>-267.184784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>1960-11-01</td>\n",
       "      <td>409.738464</td>\n",
       "      <td>440.991150</td>\n",
       "      <td>458.693176</td>\n",
       "      <td>690.632263</td>\n",
       "      <td>697.118591</td>\n",
       "      <td>703.821167</td>\n",
       "      <td>-34.719048</td>\n",
       "      <td>11.713625</td>\n",
       "      <td>3.806945</td>\n",
       "      <td>-246.174759</td>\n",
       "      <td>-267.841095</td>\n",
       "      <td>-248.934937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>1960-12-01</td>\n",
       "      <td>421.198425</td>\n",
       "      <td>443.388397</td>\n",
       "      <td>456.959534</td>\n",
       "      <td>697.118591</td>\n",
       "      <td>703.821167</td>\n",
       "      <td>710.523743</td>\n",
       "      <td>11.713625</td>\n",
       "      <td>3.806945</td>\n",
       "      <td>-24.743301</td>\n",
       "      <td>-287.633820</td>\n",
       "      <td>-264.239685</td>\n",
       "      <td>-228.820923</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            ds       step0       step1       step2      trend0      trend1  \\\n",
       "136 1960-10-01  465.914337  410.882660  441.647461  683.929749  690.632263   \n",
       "137 1960-11-01  409.738464  440.991150  458.693176  690.632263  697.118591   \n",
       "138 1960-12-01  421.198425  443.388397  456.959534  697.118591  703.821167   \n",
       "\n",
       "         trend2  season_yearly0  season_yearly1  season_yearly2         ar0  \\\n",
       "136  697.118591      -19.961983      -34.719048       11.713625 -198.053421   \n",
       "137  703.821167      -34.719048       11.713625        3.806945 -246.174759   \n",
       "138  710.523743       11.713625        3.806945      -24.743301 -287.633820   \n",
       "\n",
       "            ar1         ar2  \n",
       "136 -245.030563 -267.184784  \n",
       "137 -267.841095 -248.934937  \n",
       "138 -264.239685 -228.820923  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(data_location + \"air_passengers.csv\")\n",
    "forecast = m.predict(df, decompose=True, raw=True)\n",
    "forecast.tail(3)"
   ]
  },
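  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two layouts hold the same numbers, arranged differently. The following is a minimal sketch of that relationship, not part of the NeuralProphet API: the rows are copied by hand from the raw output above, and the names `raw_view`, `by_start` and `by_target` are illustrative. Shifting `step<i>` down by `<i>` rows recovers the `yhat<i+1>` column of the target-based table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Last three rows of the forecast-start layout, copied from the output above.\n",
    "raw_view = pd.DataFrame(\n",
    "    {\n",
    "        \"ds\": pd.to_datetime([\"1960-10-01\", \"1960-11-01\", \"1960-12-01\"]),\n",
    "        \"step0\": [465.914337, 409.738464, 421.198425],\n",
    "        \"step1\": [410.882660, 440.991150, 443.388397],\n",
    "        \"step2\": [441.647461, 458.693176, 456.959534],\n",
    "    }\n",
    ")\n",
    "n_forecasts = 3\n",
    "\n",
    "# Add rows for targets that lie beyond the last forecast start.\n",
    "all_dates = pd.date_range(raw_view[\"ds\"].iloc[0], periods=len(raw_view) + n_forecasts - 1, freq=\"MS\")\n",
    "by_start = raw_view.set_index(\"ds\").reindex(all_dates)\n",
    "\n",
    "# step<i> made at row t targets row t + i, where it appears as yhat<i+1>.\n",
    "by_target = pd.DataFrame(index=all_dates)\n",
    "for i in range(n_forecasts):\n",
    "    by_target[f\"yhat{i + 1}\"] = by_start[f\"step{i}\"].shift(i)\n",
    "\n",
    "# The 1960-12-01 row matches yhat1, yhat2, yhat3 of the target-based table.\n",
    "by_target.loc[pd.Timestamp(\"1960-12-01\")]"
   ]
  },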
  {
   "cell_type": "markdown",
   "id": "c1d12b4a-13a5-4eea-9a41-da1337bc7b99",
   "metadata": {},
   "source": [
    "## Collect out-of-sample predictions\n",
    "This is how you can extend predictions into the unknown future:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "276f79da-e024-4c12-8a68-dc59885a3e92",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(data_location + \"air_passengers.csv\")\n",
    "future = m.make_future_dataframe(df, periods=3) # periods=m.n_forecasts, n_historic_predictions=False"
   ]
  },
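  {
   "cell_type": "markdown",
   "id": "8ab68b04-a6b8-45d1-a9bd-8c4d33d16a40",
   "metadata": {},
   "source": [
    "The future dataframe holds the last observed rows that the model needs as inputs, followed by the `periods=3` future rows whose `y` is still unknown. Forecasts made from it therefore only contain predictions about the yet unobserved future."
   ]
  },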
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0bdfbbbc-730f-470c-ae86-f1e2d2ac7b56",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1960-11-01</td>\n",
       "      <td>390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1960-12-01</td>\n",
       "      <td>432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1961-01-01</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1961-02-01</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1961-03-01</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          ds     y\n",
       "3 1960-11-01   390\n",
       "4 1960-12-01   432\n",
       "5 1961-01-01  None\n",
       "6 1961-02-01  None\n",
       "7 1961-03-01  None"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "future.tail()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cbb621e-4abd-406b-8c16-23361d9aac15",
   "metadata": {},
   "source": [
    "## Predictions based on forecast target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "97275377-ca98-417a-9234-897f48cdfd57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>y</th>\n",
       "      <th>yhat1</th>\n",
       "      <th>residual1</th>\n",
       "      <th>yhat2</th>\n",
       "      <th>residual2</th>\n",
       "      <th>yhat3</th>\n",
       "      <th>residual3</th>\n",
       "      <th>ar1</th>\n",
       "      <th>ar2</th>\n",
       "      <th>ar3</th>\n",
       "      <th>trend</th>\n",
       "      <th>season_yearly</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1961-01-01</td>\n",
       "      <td>NaN</td>\n",
       "      <td>451.707611</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>-255.920502</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>703.821167</td>\n",
       "      <td>3.806945</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1961-02-01</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>465.932037</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>-219.848434</td>\n",
       "      <td>None</td>\n",
       "      <td>710.523743</td>\n",
       "      <td>-24.743301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1961-03-01</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>525.330139</td>\n",
       "      <td>NaN</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>-174.258484</td>\n",
       "      <td>716.577637</td>\n",
       "      <td>-16.989017</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          ds   y       yhat1 residual1       yhat2 residual2       yhat3  \\\n",
       "5 1961-01-01 NaN  451.707611       NaN        None       NaN        None   \n",
       "6 1961-02-01 NaN        None       NaN  465.932037       NaN        None   \n",
       "7 1961-03-01 NaN        None       NaN        None       NaN  525.330139   \n",
       "\n",
       "  residual3         ar1         ar2         ar3       trend season_yearly  \n",
       "5       NaN -255.920502        None        None  703.821167      3.806945  \n",
       "6       NaN        None -219.848434        None  710.523743    -24.743301  \n",
       "7       NaN        None        None -174.258484  716.577637    -16.989017  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forecast = m.predict(future)\n",
    "forecast.tail(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccd0af24-2aa1-45dc-884f-f27fd71818a2",
   "metadata": {},
   "source": [
    "## Predictions based on forecast start\n",
    "We can also get the forecasts based on the forecast start by passing `raw=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d7fa8af1-c921-4e3c-96e8-d004a1adc751",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ds</th>\n",
       "      <th>step0</th>\n",
       "      <th>step1</th>\n",
       "      <th>step2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1961-01-01</td>\n",
       "      <td>451.707611</td>\n",
       "      <td>465.932037</td>\n",
       "      <td>525.330139</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          ds       step0       step1       step2\n",
       "0 1961-01-01  451.707611  465.932037  525.330139"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forecast = m.predict(future, raw=True, decompose=False)\n",
    "forecast"
   ]
  },
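  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To read the single raw row by target date, each `step<i>` can be paired with the datetime `<i>` periods after the forecast start. This is a minimal sketch with the values copied by hand from the output above; the names `future_by_target`, `column` and `yhat` are illustrative and not produced by NeuralProphet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# The single forecast-start row, copied from the output above.\n",
    "start = pd.Timestamp(\"1961-01-01\")\n",
    "steps = [451.707611, 465.932037, 525.330139]\n",
    "\n",
    "# step<i> targets the datetime <i> monthly periods after the forecast start.\n",
    "targets = pd.date_range(start, periods=len(steps), freq=\"MS\")\n",
    "future_by_target = pd.DataFrame(\n",
    "    {\n",
    "        \"ds\": targets,\n",
    "        # Column of the target-based table in which this value appears.\n",
    "        \"column\": [f\"yhat{i + 1}\" for i in range(len(steps))],\n",
    "        \"yhat\": steps,\n",
    "    }\n",
    ")\n",
    "future_by_target"
   ]
  },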
  {
   "cell_type": "markdown",
   "id": "83584895-99c0-4c90-baca-73e946258f3f",
   "metadata": {},
   "source": [
    "### Advanced: Get predictions based on forecast start as arrays\n",
    "The private method `_predict_raw` was not meant to be used directly, but if you have a specific need, it may be useful to get the values directly as arrays.\n",
    "The returned predictions are also based on the forecast start."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56a7ff91-562d-4bfd-a2ba-17de298b4606",
   "metadata": {},
   "source": [
    "First prepare the dataframe, then collect the dates, predictions and components as arrays:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "777932c7-09a0-4f86-9cc2-0485f78f96f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = m._prepare_dataframe_to_predict(future)\n",
    "dates, predicted, components = m._predict_raw(df, include_components=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ae18bd3a-72c2-4361-a2cc-689c47936134",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5   1961-01-01\n",
       "Name: ds, dtype: datetime64[ns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dates[-3:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5f438921-168a-4d09-86bd-2e8dc3e98f22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[451.7076 , 465.93204, 525.33014]], dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted[-3:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f4686019-8954-495e-a387-cfdabd1e52e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('trend', array([[703.82117, 710.52374, 716.57764]], dtype=float32)),\n",
       " ('season_yearly',\n",
       "  array([[  3.806945, -24.743301, -16.989017]], dtype=float32)),\n",
       " ('ar', array([[-255.9205 , -219.84843, -174.25848]], dtype=float32))]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(key, values[-3:]) for key, values in components.items()]"
   ]
  },
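  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arrays above can be cross-checked against each other. This is a minimal sketch: the arrays are copied by hand from the displayed outputs into the illustrative names `predicted_copy` and `components_copy`, and the tolerance is an assumption chosen to absorb float32 rounding. Here each array has one row (one forecast start) and three columns (three steps), and the components sum to the prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Arrays copied from the outputs above.\n",
    "predicted_copy = np.array([[451.7076, 465.93204, 525.33014]], dtype=np.float32)\n",
    "components_copy = {\n",
    "    \"trend\": np.array([[703.82117, 710.52374, 716.57764]], dtype=np.float32),\n",
    "    \"season_yearly\": np.array([[3.806945, -24.743301, -16.989017]], dtype=np.float32),\n",
    "    \"ar\": np.array([[-255.9205, -219.84843, -174.25848]], dtype=np.float32),\n",
    "}\n",
    "\n",
    "# Element-wise sum over all components, one value per forecast step.\n",
    "total = sum(components_copy.values())\n",
    "arrays_ok = np.allclose(total, predicted_copy, atol=1e-3)\n",
    "\n",
    "print(total.shape, arrays_ok)"
   ]
  },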
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb1763f3-8bf2-4d29-a19d-25b46dee9413",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "np-dev",
   "language": "python",
   "name": "np-dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
